{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Mish Derivatves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-500.3069, -490.6361, -480.3858, -471.2755, -459.0872, -451.1570,\n",
       "        -440.2400, -429.6230, -419.9467, -408.3055, -402.3395, -389.1660,\n",
       "        -380.1614, -369.7649, -359.7261, -348.8759, -338.7170, -329.2680,\n",
       "        -319.7470, -309.6079, -301.3083, -290.9236, -279.3832, -267.8622,\n",
       "        -259.3479, -249.7400, -240.8742, -229.2343, -219.3999, -210.0166,\n",
       "        -199.8259, -191.5603, -178.9595, -171.4488, -160.3362, -150.1327,\n",
       "        -139.2230, -130.8046, -121.8909, -108.4913, -100.5724,  -88.9087,\n",
       "         -79.9365,  -70.3478,  -60.1005,  -49.9595,  -37.6322,  -29.9353,\n",
       "         -18.9407,  -11.9213,   -2.5633,   10.6869,   18.9005,   29.4622,\n",
       "          41.7188,   49.6080,   59.3583,   71.3071,   80.2604,   91.4908,\n",
       "         100.2913,  108.7626,  118.8391,  129.8859,  139.5593,  150.6612,\n",
       "         161.5152,  170.5409,  179.0472,  187.5896,  199.0938,  210.3955,\n",
       "         221.2551,  229.1151,  240.5497,  250.7286,  260.3474,  268.6524,\n",
       "         280.6704,  291.0199,  302.0525,  309.1079,  320.3692,  330.5589,\n",
       "         340.5503,  350.1638,  359.5840,  369.0214,  379.5835,  390.5975,\n",
       "         398.7851,  408.7201,  420.3418,  430.0718,  440.4431,  449.6514,\n",
       "         459.1114,  468.1187,  480.6921,  490.0807])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inp = torch.randn(100) + (torch.arange(0, 1000, 10, dtype=torch.float)-500.)\n",
    "inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import sympy\n",
    "from sympy import Symbol, Function, Expr, diff, simplify, exp, log, tanh\n",
    "x = Symbol('x')\n",
    "f =  Function('f')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Overall Derivative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{x \\left(1 - \\tanh^{2}{\\left(\\log{\\left(e^{x} + 1 \\right)} \\right)}\\right) e^{x}}{e^{x} + 1} + \\tanh{\\left(\\log{\\left(e^{x} + 1 \\right)} \\right)}$"
      ],
      "text/plain": [
       "x*(1 - tanh(log(exp(x) + 1))**2)*exp(x)/(exp(x) + 1) + tanh(log(exp(x) + 1))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff(x*tanh(log(exp(x)+1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle - \\frac{x \\left(\\tanh^{2}{\\left(\\log{\\left(e^{x} + 1 \\right)} \\right)} - 1\\right) e^{x} - \\left(e^{x} + 1\\right) \\tanh{\\left(\\log{\\left(e^{x} + 1 \\right)} \\right)}}{e^{x} + 1}$"
      ],
      "text/plain": [
       "-(x*(tanh(log(exp(x) + 1))**2 - 1)*exp(x) - (exp(x) + 1)*tanh(log(exp(x) + 1)))/(exp(x) + 1)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simplify(diff(x*tanh(log(exp(x)+1))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Softplus\n",
    "\n",
    "$ \\Large \\frac{\\partial}{\\partial x} Softplus(x) = 1 - \\frac{1}{e^{x} + 1} $\n",
    "\n",
    "Or, from PyTorch:\n",
    "\n",
    "$ \\Large \\frac{\\partial}{\\partial x} Softplus(x) = 1 - e^{-Y} $\n",
    "\n",
    "Where $Y$ is saved output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "class SoftPlusTest(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, inp, threshold=20):\n",
    "        y = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)\n",
    "        ctx.save_for_backward(y)\n",
    "        return y\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_out):\n",
    "        y, = ctx.saved_tensors\n",
    "        res = 1 - (-y).exp_()\n",
    "        return grad_out * res\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(F.softplus(inp), SoftPlusTest.apply(inp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.gradcheck(SoftPlusTest.apply, inp.to(torch.float64).requires_grad_())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## $tanh(Softplus(x))$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left(1 - \\tanh^{2}{\\left(f{\\left(x \\right)} \\right)}\\right) \\frac{d}{d x} f{\\left(x \\right)}$"
      ],
      "text/plain": [
       "(1 - tanh(f(x))**2)*Derivative(f(x), x)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff(tanh(f(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "class TanhSPTest(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, inp, threshold=20):\n",
    "        ctx.save_for_backward(inp)\n",
    "        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)\n",
    "        y = torch.tanh(sp)\n",
    "        return y\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_out, threshold=20):\n",
    "        inp, = ctx.saved_tensors\n",
    "        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)\n",
    "        grad_sp = 1 - torch.exp(-sp)\n",
    "        tanhsp = torch.tanh(sp)\n",
    "        grad = (1 - tanhsp*tanhsp) * grad_sp\n",
    "        return grad_out * grad\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(TanhSPTest.apply(inp), torch.tanh(F.softplus(inp)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.gradcheck(TanhSPTest.apply, inp.to(torch.float64).requires_grad_())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle x \\frac{d}{d x} f{\\left(x \\right)} + f{\\left(x \\right)}$"
      ],
      "text/plain": [
       "x*Derivative(f(x), x) + f(x)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff(x * f(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle x \\left(1 - \\tanh^{2}{\\left(f{\\left(x \\right)} \\right)}\\right) \\frac{d}{d x} f{\\left(x \\right)} + \\tanh{\\left(f{\\left(x \\right)} \\right)}$"
      ],
      "text/plain": [
       "x*(1 - tanh(f(x))**2)*Derivative(f(x), x) + tanh(f(x))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff(x*tanh(f(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{x \\frac{d}{d x} f{\\left(x \\right)}}{\\cosh^{2}{\\left(f{\\left(x \\right)} \\right)}} + \\tanh{\\left(f{\\left(x \\right)} \\right)}$"
      ],
      "text/plain": [
       "x*Derivative(f(x), x)/cosh(f(x))**2 + tanh(f(x))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simplify(diff(x*tanh(f(x))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left(1 - \\tanh^{2}{\\left(f{\\left(x \\right)} \\right)}\\right) \\frac{d}{d x} f{\\left(x \\right)}$"
      ],
      "text/plain": [
       "(1 - tanh(f(x))**2)*Derivative(f(x), x)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff(tanh(f(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "class MishTest(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, inp, threshold=20):\n",
    "        ctx.save_for_backward(inp)\n",
    "        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)\n",
    "        tsp = torch.tanh(sp)\n",
    "        y = inp.mul(tsp)\n",
    "        return y\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_out, threshold=20):\n",
    "        inp, = ctx.saved_tensors\n",
    "        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)\n",
    "        grad_sp = 1 - torch.exp(-sp)\n",
    "        tsp = torch.tanh(sp)\n",
    "        grad_tsp = (1 - tsp*tsp) * grad_sp\n",
    "        grad = inp * grad_tsp + tsp\n",
    "        return grad_out * grad\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(MishTest.apply(inp), inp.mul(torch.tanh(F.softplus(inp))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.gradcheck(TanhSPTest.apply, inp.to(torch.float64).requires_grad_())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-fastai]",
   "language": "python",
   "name": "conda-env-.conda-fastai-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
